{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Batch Predictions\n",
    "\n",
    "References:\n",
    "\n",
    "* <https://github.com/awslabs/amazon-sagemaker-examples/blob/master/sagemaker_batch_transform/tensorflow_cifar-10_with_inference_script/tensorflow-serving-cifar10-python-sdk.ipynb>\n",
    "* <https://aws.amazon.com/blogs/machine-learning/performing-batch-inference-with-tensorflow-serving-in-amazon-sagemaker/>"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import boto3\n",
    "import sagemaker\n",
    "import pandas as pd\n",
    "\n",
    "# Execution role, SageMaker session, and the session's default S3 bucket.\n",
    "role = sagemaker.get_execution_role()\n",
    "sess = sagemaker.Session()\n",
    "bucket = sess.default_bucket()\n",
    "\n",
    "# SageMaker API client created for the region reported by a default boto3 session.\n",
    "region = boto3.Session().region_name\n",
    "sm = boto3.Session().client('sagemaker', region_name=region)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "import sys\n",
    "\n",
    "# Install with the interpreter that runs this kernel; a bare `!pip` can resolve to a different Python.\n",
    "# NOTE: `sagemaker` has already been imported by the previous cell, so restart the kernel after an\n",
    "# upgrade for the pinned version to actually be the one in use.\n",
    "!{sys.executable} -m pip install --upgrade sagemaker==1.56.1"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "!pip list"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Setup Input Data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "%store -r scikit_processing_job_name"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "print(scikit_processing_job_name)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Echo the restored processing job name with a label.\n",
    "print('Previous Scikit Processing Job Name:', scikit_processing_job_name)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# The test split is expected under '<processing job name>/output/bert-test' in the bucket.\n",
    "prefix_test = '{job_name}/output/bert-test'.format(job_name=scikit_processing_job_name)\n",
    "test_s3_uri = 's3://{bucket}/{prefix}'.format(bucket=bucket, prefix=prefix_test)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Show the input location, then list the objects stored under it.\n",
    "print(test_s3_uri)\n",
    "!aws s3 ls {test_s3_uri}/"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Setup Batch Transform Model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "%store -r training_job_name"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "print(training_job_name)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "!saved_model_cli show --all --dir ./tensorflow/saved_model/0/"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from sagemaker.tensorflow.serving import Model\n",
    "\n",
    "# Environment variables handed to the TensorFlow Serving container.\n",
    "# Keep SAGEMAKER_TFS_DEFAULT_MODEL_NAME as 'saved_model' unless you are using multi-model:\n",
    "# if the model name doesn't exist you may see the dreaded ping error in the logs.\n",
    "batch_env = {\n",
    "    'SAGEMAKER_TFS_DEFAULT_MODEL_NAME': 'saved_model',\n",
    "    'SAGEMAKER_TFS_ENABLE_BATCHING': 'true',\n",
    "    'SAGEMAKER_TFS_BATCH_TIMEOUT_MICROS': '50000',\n",
    "    'SAGEMAKER_TFS_MAX_BATCH_SIZE': '16',\n",
    "}\n",
    "\n",
    "# Location where the training job's model artifact is expected.\n",
    "batch_model_data = 's3://{}/{}/output/model.tar.gz'.format(bucket, training_job_name)\n",
    "\n",
    "# entry_point cannot change from 'inference.py'. See https://github.com/aws/sagemaker-python-sdk/issues/1451\n",
    "batch_model = Model(\n",
    "    model_data=batch_model_data,\n",
    "    role=role,\n",
    "    framework_version='2.1.0',\n",
    "    entry_point='inference.py',\n",
    "    source_dir='src_batch_tfrecord',\n",
    "    env=batch_env,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "# Build the batch transformer from the model defined above.\n",
    "batch_predictor = batch_model.transformer(\n",
    "    instance_type='ml.c5.9xlarge',\n",
    "    instance_count=1,\n",
    "    strategy='MultiRecord',\n",
    "    assemble_with='Line',\n",
    "    max_payload=100,  # megabytes, not a number of records\n",
    "    env=batch_env,\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Start Batch Predictions"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Start the transform job on the TFRecord test data without blocking; a later cell waits for it.\n",
    "batch_predictor.transform(\n",
    "    test_s3_uri,\n",
    "    split_type='TFRecord',\n",
    "    compression_type=None,\n",
    "    experiment_config=None,\n",
    "    wait=False,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Import from the public IPython.display module (the IPython.core.display path is deprecated).\n",
    "from IPython.display import display, HTML\n",
    "\n",
    "# Render a clickable link to the latest transform job in the SageMaker console.\n",
    "transform_job_name = batch_predictor.latest_transform_job.job_name\n",
    "transform_job_url = 'https://console.aws.amazon.com/sagemaker/home?region={}#/transform-jobs/{}?region={}&tab=Monitor'.format(region, transform_job_name, region)\n",
    "\n",
    "display(HTML('<b>Review <a href=\"{}\">Batch Prediction Job</a></b>'.format(transform_job_url)))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Import from the public IPython.display module (the IPython.core.display path is deprecated).\n",
    "from IPython.display import display, HTML\n",
    "\n",
    "# Render a clickable link to the CloudWatch log streams prefixed with the latest transform job's name.\n",
    "transform_job_name = batch_predictor.latest_transform_job.job_name\n",
    "cloudwatch_logs_url = 'https://console.aws.amazon.com/cloudwatch/home?region={}#logStream:group=/aws/sagemaker/TransformJobs;prefix={};streamFilter=typeLogStreamPrefix'.format(region, transform_job_name)\n",
    "\n",
    "display(HTML('<b>Review <a href=\"{}\">CloudWatch Logs</a></b>'.format(cloudwatch_logs_url)))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Import from the public IPython.display module (the IPython.core.display path is deprecated).\n",
    "from IPython.display import display, HTML\n",
    "\n",
    "# Render a clickable link to the S3 console location named after the latest transform job.\n",
    "transform_job_name = batch_predictor.latest_transform_job.job_name\n",
    "s3_output_url = 'https://console.aws.amazon.com/s3/buckets/{}/{}/?region={}'.format(bucket, transform_job_name, region)\n",
    "\n",
    "display(HTML('<b>Review <a href=\"{}\">Batch Prediction S3 Output</a></b>'.format(s3_output_url)))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Announce which job we are waiting on, then wait with log output turned off.\n",
    "transform_job_name = batch_predictor.latest_transform_job.job_name\n",
    "print('Waiting for batch prediction job: {}'.format(transform_job_name))\n",
    "\n",
    "batch_predictor.wait(logs=False)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Check Output Data\n",
    "\n",
    "After the transform job has completed, download the output data from S3.\n",
    "\n",
    "For each file `f` in the input data, there is a corresponding file `f.out` containing the predicted labels for that file's input rows.\n",
    "\n",
    "We can then compare the predicted labels to the true labels saved earlier.\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Keep the transformer's S3 output location in a variable for the download step that follows.\n",
    "batch_prediction_output_s3_uri = batch_predictor.output_path"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Download the output data from S3 to the local filesystem.\n",
    "# Use `!` lines rather than a `%%bash` cell: IPython substitutes Python variables into `!` commands,\n",
    "# whereas the body of a `%%bash` cell is passed to the shell as-is, where the variable is undefined.\n",
    "!aws s3 cp --recursive $batch_prediction_output_s3_uri/ ./batch_prediction_output\n",
    "\n",
    "!ls ./batch_prediction_output"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "conda_python3",
   "language": "python",
   "name": "conda_python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.5"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
